#!/usr/bin/env python
# -*- coding: utf-8 -*-
# vim: tabstop=4 shiftwidth=4 softtabstop=4
#
# Copyright (C) 2021 GEM Foundation
#
# OpenQuake is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# OpenQuake is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with OpenQuake. If not, see <http://www.gnu.org/licenses/>.
import os
import ast
from openquake.baselib import performance
from openquake.commonlib import datastore


def vm_xml(loss_type, vfs):
    """
    Wrap already rendered vulnerability functions in a NRML 0.5 document.

    :param loss_type: value inserted verbatim in the lossCategory attribute
    :param vfs: text placed inside the <vulnerabilityModel> element
    :returns: the whole XML document as a string ending with a newline
    """
    model_tag = ('<vulnerabilityModel id="vuln_model" '
                 'assetCategory="buildings" '
                 f'lossCategory="{loss_type}">')
    lines = (
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<nrml xmlns="http://openquake.org/xmlns/nrml/0.5">',
        model_tag,
        '',  # blank line separating the opening tag from the functions
        f'{vfs}',
        '</vulnerabilityModel>',
        '</nrml>',
        '',  # makes the document end with a newline
    )
    return '\n'.join(lines)


def vf_xml(dic):
    """
    Render one vulnerability function as a NRML XML fragment.

    :param dic:
        a mapping with the keys ``id``, ``distribution_name``, ``imt``,
        ``imls``, ``covs`` and ``mean_loss_ratios``; the last three are
        iterables whose items are converted with ``str`` and joined with
        single spaces. The mapping is left untouched. Values are inserted
        verbatim, without any XML escaping.
    :returns:
        the ``<vulnerabilityFunction>`` element as a string ending with
        a newline
    """
    def joined(key):
        return ' '.join(map(str, dic[key]))

    # Format from a copy: overwriting the caller's dictionary would replace
    # its sequences with strings, so that a second call on the same
    # dictionary would join the single characters instead of the values.
    fields = dict(dic, imls=joined('imls'), covs=joined('covs'),
                  means=joined('mean_loss_ratios'))
    templ = '''\
<vulnerabilityFunction id="{id}" dist="{distribution_name}">
    <imls imt="{imt}"> {imls} </imls>
    <meanLRs> {means} </meanLRs>
    <covLRs> {covs} </covLRs>
</vulnerabilityFunction>
'''
    return templ.format(**fields)


def main(calc_id: int = -1, extract_dir='.'):
    """
    Extract vulnerability models from the datastore in XML format.

    The 'crm' dataframe of the calculation is grouped by loss type; for
    each loss type having at least one risk function whose class name ends
    with 'VulnerabilityFunction', a file <loss_type>_vulnerability.xml is
    written inside `extract_dir` and its path is printed. The monitor is
    printed at the end only if its duration is greater than 1.
    """
    with performance.Monitor('extract', measuremem=True) as mon:
        crm = datastore.read(calc_id).read_df('crm')
        for raw_loss_type, group in crm.groupby('loss_type'):
            # the loss types and the risk functions are stored as bytes
            loss_type = raw_loss_type.decode('ascii')
            vfs = []
            for _riskid, raw_func in zip(group.riskid, group.riskfunc):
                parsed = ast.literal_eval(raw_func.decode('ascii'))
                # each literal must be a dictionary with exactly one item,
                # class name -> attributes, else the unpacking fails
                [(clsname, dic)] = parsed.items()
                if not clsname.endswith('VulnerabilityFunction'):
                    continue
                vfs.append(vf_xml(dic))
            if not vfs:
                # nothing to save for this loss type
                continue
            fname = os.path.join(
                extract_dir, f'{loss_type}_vulnerability.xml')
            with open(fname, 'w', encoding='utf-8') as out:
                out.write(vm_xml(loss_type, '\n'.join(vfs)))
            print('Saved', fname)
    if mon.duration > 1:
        print(mon)


# Attach to `main` one description string per parameter (presumably read
# by the command-line wrapper to build the help message -- TODO confirm)
main.calc_id = 'number of the calculation'
main.extract_dir = 'directory where to extract'
